#include "MNISTDataLoader.h"

/**
 * @brief Builds the MNIST data loader held in `core`.
 *
 * Constructs a torch::data::datasets::MNIST dataset rooted at @p pos and maps it
 * through torch::data::transforms::Stack<>. It then wraps the result in a data
 * loader via torch::data::make_data_loader with a batch size of 64.
 *
 * @param train If true, the dataset is opened in MNIST::Mode::kTrain. Otherwise it
 *              is opened in MNIST::Mode::kTest.
 * @param pos   Path passed as the root argument to the MNIST dataset constructor.
 *
 * Any exception thrown by the MNIST constructor (e.g. when the files cannot be
 * read) propagates to the caller unchanged.
 */
data::MNISTDataLoader::MNISTDataLoader(bool train, const std::string &pos)
        : core(torch::data::make_data_loader(
                  // Only the split differs between train and test. Selecting the mode
                  // up front keeps a single pipeline, so both cases produce the same
                  // dataset/loader type.
                  torch::data::datasets::MNIST(
                          pos,
                          train ? torch::data::datasets::MNIST::Mode::kTrain
                                : torch::data::datasets::MNIST::Mode::kTest)
                          .map(torch::data::transforms::Stack<>()),
                  /*batch_size=*/64)) {
    // make_data_loader returns a prvalue, so `core` is initialized directly from it.
    // Wrapping it in std::move would only force an extra move construction.
}